{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "T4"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# Object Detection using YOLOv7"
      ],
      "metadata": {
        "id": "HbJG-0vkfOWt"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Use the %pip magic so the package is installed into the environment of the running kernel.\n",
        "%pip install -q tf-models-official"
      ],
      "metadata": {
        "id": "uo0ZQ1wBBSHf"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Import Necessary Libraries"
      ],
      "metadata": {
        "id": "SU3iw-I5fTll"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import os\n",
        "import logging\n",
        "import yaml\n",
        "import tempfile\n",
        "import pprint\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "from PIL import Image\n",
        "from six import BytesIO\n",
        "from IPython import display\n",
        "from urllib.request import urlopen\n",
        "\n",
        "# Suppress log records at WARNING level and below for the rest of the session.\n",
        "logging.disable(logging.WARNING)"
      ],
      "metadata": {
        "id": "z_yMmIayFgKJ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import orbit\n",
        "import tensorflow as tf\n",
        "import tensorflow_models as tfm\n",
        "\n",
        "from official.core import exp_factory\n",
        "from official.core import config_definitions as cfg\n",
        "from official.projects.yolo.common import registry_imports\n",
        "from official.projects.yolo.configs import yolov7\n",
        "from official.vision.serving import export_saved_model_lib\n",
        "from official.vision.ops.preprocess_ops import normalize_image\n",
        "from official.vision.ops.preprocess_ops import resize_and_crop_image\n",
        "from official.vision.utils.object_detection import visualization_utils\n",
        "from official.vision.dataloaders.tf_example_decoder import TfExampleDecoder\n",
        "\n",
        "\n",
        "pp = pprint.PrettyPrinter(indent=4) # Pretty-printer with 4-space indentation; used later to dump the experiment config\n",
        "print(tf.__version__) # Show which TensorFlow version this runtime provides\n",
        "\n",
        "# Render matplotlib figures inline in the notebook output.\n",
        "%matplotlib inline"
      ],
      "metadata": {
        "id": "G34F8TBOF1Se"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Download BCCD(Blood Cells) dataset."
      ],
      "metadata": {
        "id": "MfhQ5C-bfV_v"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Download the BCCD dataset archive (COCO-format export) into the working directory.\n",
        "!curl -L 'https://public.roboflow.com/ds/ZpYLqHeT0W?key=ZXfZLRnhsc' > './BCCD.v1-bccd.coco.zip'\n",
        "# Extract into './BCC.v1-bccd.coco/' -- note the directory is spelled 'BCC' (no trailing 'D');\n",
        "# the *_DATA_DIR paths in the conversion cells below rely on this exact name.\n",
        "!unzip -q -o './BCCD.v1-bccd.coco.zip' -d './BCC.v1-bccd.coco/'\n",
        "# The archive is no longer needed once it has been extracted.\n",
        "!rm './BCCD.v1-bccd.coco.zip'"
      ],
      "metadata": {
        "id": "DzXHG3CnGPJv"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Convert COCO format dataset to tfrecords"
      ],
      "metadata": {
        "id": "Lgw_SGaVfcX3"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Training TFRecords"
      ],
      "metadata": {
        "id": "9IzGpt2JfdEj"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "TRAIN_DATA_DIR='./BCC.v1-bccd.coco/train'\n",
        "TRAIN_ANNOTATION_FILE_DIR='./BCC.v1-bccd.coco/train/_annotations.coco.json'\n",
        "OUTPUT_TFRECORD_TRAIN='./bccd_coco_tfrecords/train'\n",
        "\n",
        "# The converter needs:\n",
        "  # 1. image_dir: directory that contains the images\n",
        "  # 2. object_annotations_file: JSON file that lists the annotations\n",
        "  # 3. output_file_prefix: path prefix for the converted TFRecord files\n",
        "# With --num_shards=1 the later cells read the result as\n",
        "# './bccd_coco_tfrecords/train-00000-of-00001.tfrecord'.\n",
        "!python -m official.vision.data.create_coco_tf_record --logtostderr \\\n",
        "  --image_dir={TRAIN_DATA_DIR} \\\n",
        "  --object_annotations_file={TRAIN_ANNOTATION_FILE_DIR} \\\n",
        "  --output_file_prefix={OUTPUT_TFRECORD_TRAIN} \\\n",
        "  --num_shards=1"
      ],
      "metadata": {
        "id": "ka8AuE9KGu9c"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Validation TFRecords"
      ],
      "metadata": {
        "id": "x8WGWESzff5c"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "VALID_DATA_DIR='./BCC.v1-bccd.coco/valid'\n",
        "VALID_ANNOTATION_FILE_DIR='./BCC.v1-bccd.coco/valid/_annotations.coco.json'\n",
        "OUTPUT_TFRECORD_VALID='./bccd_coco_tfrecords/valid'\n",
        "\n",
        "# Same conversion as for the training split, applied to the validation images and annotations.\n",
        "# The later cells read the result as './bccd_coco_tfrecords/valid-00000-of-00001.tfrecord'.\n",
        "!python -m official.vision.data.create_coco_tf_record --logtostderr \\\n",
        "  --image_dir={VALID_DATA_DIR} \\\n",
        "  --object_annotations_file={VALID_ANNOTATION_FILE_DIR} \\\n",
        "  --output_file_prefix={OUTPUT_TFRECORD_VALID} \\\n",
        "  --num_shards=1"
      ],
      "metadata": {
        "id": "gqLhou6_HKo0"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Test TFRecords"
      ],
      "metadata": {
        "id": "f2cnxT0XfjpY"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "TEST_DATA_DIR='./BCC.v1-bccd.coco/test'\n",
        "TEST_ANNOTATION_FILE_DIR='./BCC.v1-bccd.coco/test/_annotations.coco.json'\n",
        "OUTPUT_TFRECORD_TEST='./bccd_coco_tfrecords/test'\n",
        "\n",
        "# Same conversion as for the train/valid splits. The Python variables are interpolated with the\n",
        "# {name} form, matching the other conversion cells.\n",
        "!python -m official.vision.data.create_coco_tf_record --logtostderr \\\n",
        "  --image_dir={TEST_DATA_DIR} \\\n",
        "  --object_annotations_file={TEST_ANNOTATION_FILE_DIR} \\\n",
        "  --output_file_prefix={OUTPUT_TFRECORD_TEST} \\\n",
        "  --num_shards=1"
      ],
      "metadata": {
        "id": "0UHWrMtnHY-Q"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Configure the YOLOv7"
      ],
      "metadata": {
        "id": "fxNVx9kkfmkZ"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Single-shard TFRecord files written by the conversion cells above.\n",
        "train_data_input_path = './bccd_coco_tfrecords/train-00000-of-00001.tfrecord'\n",
        "valid_data_input_path = './bccd_coco_tfrecords/valid-00000-of-00001.tfrecord'\n",
        "test_data_input_path = './bccd_coco_tfrecords/test-00000-of-00001.tfrecord'\n",
        "# Directory passed to the task and to run_experiment as model_dir.\n",
        "model_dir = './trained_model/'\n",
        "# Directory intended for the exported model.\n",
        "export_dir ='./exported_model/'"
      ],
      "metadata": {
        "id": "9bRtYTviIO9s"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Download YOLOv7 config file"
      ],
      "metadata": {
        "id": "3C99W7mJfr-Q"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# -O always writes to ./yolov7_gpu.yaml, so re-running this cell refreshes the file instead of\n",
        "# leaving numbered duplicates next to a stale original.\n",
        "!wget -O yolov7_gpu.yaml https://raw.githubusercontent.com/tensorflow/models/master/official/projects/yolo/configs/experiments/yolov7/detection/yolov7_gpu.yaml"
      ],
      "metadata": {
        "id": "RAoNvLSLPlwq"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "In Model Garden, the collections of parameters that define a model are called configs. Model Garden can create a config based on a known set of parameters via a factory.\n",
        "\n",
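        "Use the `yolo_darknet` experiment configuration as the starting point; the YOLO experiment configuration files are available [here](https://github.com/tensorflow/models/tree/master/official/projects/yolo/configs/experiments).\n",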
        "\n",
        "The `yolo_darknet` configuration defines a YOLO object-detection experiment. The default configuration is set up for training on COCO train2017 and evaluation on COCO val2017; in this notebook it is overridden with the settings from `yolov7_gpu.yaml` and pointed at the BCCD TFRecords.\n",
        "\n",
        "There are also other alternative experiments available such as retinanet_resnetfpn_coco, retinanet_spinenet_coco, fasterrcnn_resnetfpn_coco and more. One can switch to them by changing the experiment name argument to the get_exp_config function.\n",
        "\n",
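        "Note that the task configuration below sets `init_checkpoint` to an empty string, so no pretrained backbone checkpoint is loaded in this notebook. To fine-tune from a pretrained checkpoint instead, set `init_checkpoint` to its path.\n",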
        "\n",
        "\n"
      ],
      "metadata": {
        "id": "M-DgdoJWfv46"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# NOTE(review): the notebook targets YOLOv7 (title, yolov7 config import, yolov7_gpu.yaml), but the base\n",
        "# experiment requested here is 'yolo_darknet' -- confirm this is the intended starting config.\n",
        "exp_config = exp_factory.get_exp_config('yolo_darknet')"
      ],
      "metadata": {
        "id": "nHz6-t-aIXz4"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Layer the downloaded YOLOv7 YAML settings on top of the base experiment config.\n",
        "# NOTE(review): yaml.full_load is applied to a file fetched from the network; yaml.safe_load\n",
        "# would be the safer choice if the file contains only plain YAML -- confirm before switching.\n",
        "with open('./yolov7_gpu.yaml') as f:\n",
        "  params = yaml.full_load(f)\n",
        "# is_strict=False: presumably lets the YAML introduce keys the base config lacks -- TODO confirm.\n",
        "exp_config.override(params, is_strict=False)"
      ],
      "metadata": {
        "id": "xAzD4cZSPpOI"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Adjust task configuration which includes model, train_data and validation_data."
      ],
      "metadata": {
        "id": "ef_ZA-pbgFAa"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Global batch size shared by the training and validation input pipelines below.\n",
        "batch_size = 8\n",
        "# Number of object classes; the category index defined later lists Platelets, RBC and WBC.\n",
        "num_classes = 3\n",
        "\n",
        "# Model input resolution; also reused for export and inference later in the notebook.\n",
        "HEIGHT, WIDTH = 416, 416\n",
        "IMG_SIZE = [HEIGHT, WIDTH, 3]\n",
        "\n",
        "\n",
        "# Backbone config.\n",
        "# Empty init_checkpoint: presumably no pretrained weights are restored -- TODO confirm.\n",
        "exp_config.task.init_checkpoint = ''\n",
        "exp_config.task.freeze_backbone = False\n",
        "exp_config.task.annotation_file = ''\n",
        "\n",
        "# Model config.\n",
        "exp_config.task.model.input_size = IMG_SIZE\n",
        "exp_config.task.model.num_classes = num_classes\n",
        "\n",
        "# Training data config.\n",
        "exp_config.task.train_data.input_path = train_data_input_path\n",
        "exp_config.task.train_data.global_batch_size = batch_size\n",
        "# Min and max scale are both 1.0: presumably this disables random scale augmentation -- TODO confirm.\n",
        "exp_config.task.train_data.parser.aug_scale_max = 1.0\n",
        "exp_config.task.train_data.parser.aug_scale_min = 1.0\n",
        "\n",
        "# Validation data config.\n",
        "exp_config.task.validation_data.input_path = valid_data_input_path\n",
        "exp_config.task.validation_data.global_batch_size = batch_size"
      ],
      "metadata": {
        "id": "KgDas9WmI_u2"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Adjust trainer configuration"
      ],
      "metadata": {
        "id": "OwIapj-ngGb4"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Total number of training steps; also used as the cosine decay horizon below.\n",
        "train_steps = 2000\n",
        "\n",
        "# NOTE(review): the inline formulas describe how steps_per_loop and validation_steps are meant to be\n",
        "# derived, but both are hard-coded to 500 here -- confirm they match the dataset size and batch size.\n",
        "exp_config.trainer.steps_per_loop = 500 # steps_per_loop = num_of_training_examples // train_batch_size\n",
        "exp_config.trainer.summary_interval = 500\n",
        "exp_config.trainer.checkpoint_interval = 500\n",
        "exp_config.trainer.validation_interval = 500\n",
        "exp_config.trainer.validation_steps =  500 # validation_steps = num_of_validation_examples // eval_batch_size\n",
        "exp_config.trainer.train_steps = train_steps\n",
        "# Linear warmup for the first 500 steps, then a cosine learning-rate schedule over train_steps.\n",
        "exp_config.trainer.optimizer_config.warmup.linear.warmup_steps = 500\n",
        "exp_config.trainer.optimizer_config.learning_rate.type = 'cosine'\n",
        "exp_config.trainer.optimizer_config.learning_rate.cosine.decay_steps = train_steps\n",
        "exp_config.trainer.optimizer_config.learning_rate.cosine.initial_learning_rate = 3.2e-5"
      ],
      "metadata": {
        "id": "sOOtAun_JJCh"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Check out the changed configuration parameters and the default parameters for further customization of the model."
      ],
      "metadata": {
        "id": "0ZK_AfAngJpk"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Dump the full experiment configuration (overrides plus remaining defaults) as a nested dict.\n",
        "pp.pprint(exp_config.as_dict())"
      ],
      "metadata": {
        "id": "VfrJhSrfIvqu"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Setup the distribution strategy"
      ],
      "metadata": {
        "id": "kGEcZTbHgNUm"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Names of all logical devices TensorFlow reports for this runtime.\n",
        "logical_device_names = [device.name for device in tf.config.list_logical_devices()]\n",
        "# One concatenated string, so a substring test is enough to detect the accelerator type.\n",
        "joined_device_names = ''.join(logical_device_names)\n",
        "\n",
        "# Switch Keras to the mixed_float16 policy when the experiment config asks for float16.\n",
        "if exp_config.runtime.mixed_precision_dtype == tf.float16:\n",
        "  tf.keras.mixed_precision.set_global_policy('mixed_float16')\n",
        "\n",
        "# GPU is checked first, then TPU; otherwise fall back to the first listed device.\n",
        "if 'GPU' in joined_device_names:\n",
        "  distribution_strategy = tf.distribute.MirroredStrategy()\n",
        "elif 'TPU' in joined_device_names:\n",
        "  tf.tpu.experimental.initialize_tpu_system()\n",
        "  tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='/device:TPU_SYSTEM:0')\n",
        "  # NOTE(review): newer TensorFlow releases expose this as tf.distribute.TPUStrategy -- confirm before changing.\n",
        "  distribution_strategy = tf.distribute.experimental.TPUStrategy(tpu)\n",
        "else:\n",
        "  print('Warning: this will be really slow.')\n",
        "  distribution_strategy = tf.distribute.OneDeviceStrategy(logical_device_names[0])"
      ],
      "metadata": {
        "id": "u5zfMsq_JsT7"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Create the `Task` object `(tfm.core.base_task.Task)` from the `config_definitions.TaskConfig`.\n",
        "\n",
        "The Task object has all the methods necessary for building the dataset, building the model, and running training & evaluation. These methods are driven by `tfm.core.train_lib.run_experiment`."
      ],
      "metadata": {
        "id": "RR4uiiAzg3kh"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Create the task inside the strategy scope, using the task section of the experiment config.\n",
        "with distribution_strategy.scope():\n",
        "  task = tfm.core.task_factory.get_task(exp_config.task, logging_dir=model_dir)"
      ],
      "metadata": {
        "id": "6H_pcsn3JoXt"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Build inputs from given training tfrecords"
      ],
      "metadata": {
        "id": "MnFUsZsJg6QY"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Pull a single batch from the training input pipeline to sanity-check its shape, dtype and label keys.\n",
        "for images, labels in task.build_inputs(exp_config.task.train_data).take(1):\n",
        "  print()\n",
        "  print(f'images.shape: {str(images.shape):16}  images.dtype: {images.dtype!r}')\n",
        "  print(f'labels.keys: {labels.keys()}')"
      ],
      "metadata": {
        "id": "xXE_sjHQJvwW"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Create category index for each label"
      ],
      "metadata": {
        "id": "Ce_0Prspg7Av"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Decoder used below to turn a serialized example into a dict of tensors.\n",
        "tf_ex_decoder = TfExampleDecoder()\n",
        "\n",
        "# Class id -> {'id', 'name'} record, in the form the visualization helper takes as category_index.\n",
        "# Ids start at 1 and follow the order Platelets, RBC, WBC.\n",
        "category_index = {\n",
        "    class_id: {'id': class_id, 'name': class_name}\n",
        "    for class_id, class_name in enumerate(('Platelets', 'RBC', 'WBC'), start=1)\n",
        "}"
      ],
      "metadata": {
        "id": "O4pypN0jg9aM"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Helper function for visualizing the results from TFRecords."
      ],
      "metadata": {
        "id": "-IPcyGWgg_Tg"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def show_batch(raw_records, num_of_examples):\n",
        "  \"\"\"Plots serialized examples side by side with their ground-truth boxes.\n",
        "\n",
        "  Args:\n",
        "    raw_records: iterable of serialized examples, e.g. a TFRecordDataset.\n",
        "    num_of_examples: number of subplot columns in the single-row figure.\n",
        "  \"\"\"\n",
        "  plt.figure(figsize=(20, 20))\n",
        "  for index, record in enumerate(raw_records):\n",
        "    plt.subplot(1, num_of_examples, index + 1)\n",
        "    decoded = tf_ex_decoder.decode(record)\n",
        "    image = decoded['image'].numpy().astype('uint8')\n",
        "    gt_boxes = decoded['groundtruth_boxes']\n",
        "    # Ground truth carries no confidence, so every box gets score 1.0 and clears the threshold.\n",
        "    scores = np.ones(shape=(len(gt_boxes)))\n",
        "    # The same `image` array is displayed afterwards, so the helper appears to draw on it in place.\n",
        "    visualization_utils.visualize_boxes_and_labels_on_image_array(\n",
        "        image,\n",
        "        gt_boxes.numpy(),\n",
        "        decoded['groundtruth_classes'].numpy().astype('int'),\n",
        "        scores,\n",
        "        category_index=category_index,\n",
        "        use_normalized_coordinates=True,\n",
        "        max_boxes_to_draw=200,\n",
        "        min_score_thresh=0.30,\n",
        "        agnostic_mode=False,\n",
        "        instance_masks=None,\n",
        "        line_thickness=4)\n",
        "\n",
        "    plt.imshow(image)\n",
        "    plt.axis('off')\n",
        "    plt.title(f'Image-{index+1}')\n",
        "  plt.show()"
      ],
      "metadata": {
        "id": "sn3dOH61hDI4"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Visualize the training data samples"
      ],
      "metadata": {
        "id": "an1PwnGShEdt"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "buffer_size = 20\n",
        "num_of_examples = 3\n",
        "\n",
        "# Read the training TFRecord, shuffle within a small buffer, and keep only a few records to plot.\n",
        "raw_records = tf.data.TFRecordDataset(exp_config.task.train_data.input_path)\n",
        "raw_records = raw_records.shuffle(buffer_size=buffer_size)\n",
        "raw_records = raw_records.take(num_of_examples)\n",
        "show_batch(raw_records, num_of_examples)"
      ],
      "metadata": {
        "id": "2Moh4jfqhGQU"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Train and Evaluate the model using `tfm.core.train_lib.run_experiment`."
      ],
      "metadata": {
        "id": "T_DzGO6zhIB6"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Run training with interleaved evaluation using the task and strategy created above.\n",
        "# model_dir is passed as the experiment's output directory; the call returns the model and the\n",
        "# evaluation logs.\n",
        "model, eval_logs = tfm.core.train_lib.run_experiment(\n",
        "    distribution_strategy=distribution_strategy,\n",
        "    task=task,\n",
        "    mode='train_and_eval',\n",
        "    params=exp_config,\n",
        "    model_dir=model_dir,\n",
        "    run_post_eval=True)"
      ],
      "metadata": {
        "id": "Gk2kn8qyJ5Yk"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Save the trained experiment configuration"
      ],
      "metadata": {
        "id": "5Um3JlvXhKTW"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Save the experiment configuration into model_dir.\n",
        "# Presumably written as params.yaml, which the export step reads as its config file -- TODO confirm.\n",
        "tfm.core.train_utils.serialize_config(exp_config, model_dir)"
      ],
      "metadata": {
        "id": "pt1w5i2qhMd2"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Export the trained model"
      ],
      "metadata": {
        "id": "-iOdrkPwhOs1"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Derive every path from the directories configured earlier instead of Colab-only absolute paths.\n",
        "EXPORT_DIR_PATH = export_dir\n",
        "# Export the most recent checkpoint in model_dir rather than a fixed step number, so the cell keeps\n",
        "# working when train_steps or checkpoint_interval are changed.\n",
        "CHECKPOINT_PATH = tf.train.latest_checkpoint(model_dir)\n",
        "if CHECKPOINT_PATH is None:\n",
        "  raise FileNotFoundError(f'No checkpoint found in {model_dir!r}; run the training cell first.')\n",
        "CONFIG_FILE_PATH = os.path.join(model_dir, 'params.yaml')"
      ],
      "metadata": {
        "id": "QNK411CBhRF-"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Export the chosen checkpoint with the saved experiment config, for batch size 1 and the same\n",
        "# HEIGHT x WIDTH input size used for training. The next cell loads the result from the\n",
        "# 'saved_model' sub-directory of the export directory.\n",
        "!python -m official.projects.yolo.serving.export_saved_model --export_dir={EXPORT_DIR_PATH}/ \\\n",
        "                   --checkpoint_path={CHECKPOINT_PATH} \\\n",
        "                   --config_file={CONFIG_FILE_PATH} \\\n",
        "                   --batch_size=1 \\\n",
        "                   --input_image_size={HEIGHT},{WIDTH}"
      ],
      "metadata": {
        "id": "l1PsfVKUhSnK"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Load the exported model for inference"
      ],
      "metadata": {
        "id": "aRAJM37zhWQU"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Load the SavedModel from the 'saved_model' sub-directory of the export directory, relative to\n",
        "# the working directory rather than through a Colab-only absolute path.\n",
        "imported = tf.saved_model.load(os.path.join(EXPORT_DIR_PATH, 'saved_model'))\n",
        "# Serving signature used for inference below.\n",
        "model_fn = imported.signatures['serving_default']"
      ],
      "metadata": {
        "id": "tUnJI6b1hW63"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Helper functions for inference"
      ],
      "metadata": {
        "id": "Y2EpOLg6hYb0"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def load_image_into_numpy_array(path):\n",
        "  \"\"\"Load an image from a file path or URL into a batched numpy array.\n",
        "\n",
        "  The image is converted to RGB, so grayscale, palette and RGBA inputs are\n",
        "  accepted as well.\n",
        "\n",
        "  Args:\n",
        "    path: the file path to the image, or a URL starting with 'http'.\n",
        "\n",
        "  Returns:\n",
        "    uint8 numpy array with shape (1, img_height, img_width, 3); the leading\n",
        "    axis is a batch dimension of size one.\n",
        "  \"\"\"\n",
        "  if path.startswith('http'):\n",
        "    # Close the HTTP response as soon as the bytes have been read.\n",
        "    with urlopen(path) as response:\n",
        "      image_data = response.read()\n",
        "  else:\n",
        "    # Close the file handle as soon as the bytes have been read.\n",
        "    with tf.io.gfile.GFile(path, 'rb') as image_file:\n",
        "      image_data = image_file.read()\n",
        "\n",
        "  # Force three channels so that non-RGB images still satisfy the (H, W, 3) contract.\n",
        "  image = Image.open(BytesIO(image_data)).convert('RGB')\n",
        "  # np.array copies the pixels into a writable (H, W, 3) uint8 array; prepend the batch axis.\n",
        "  return np.array(image, dtype=np.uint8)[np.newaxis, ...]\n",
        "\n",
        "\n",
        "\n",
        "def build_inputs_for_object_detection(image, input_image_size):\n",
        "  \"\"\"Resizes an image to the serving input size of the detection model.\n",
        "\n",
        "  Args:\n",
        "    image: image tensor to resize.\n",
        "    input_image_size: target size, used both as the desired and the padded size.\n",
        "\n",
        "  Returns:\n",
        "    The first value returned by resize_and_crop_image; the second is discarded.\n",
        "  \"\"\"\n",
        "  # Min and max scale are both 1.0 -- presumably no random scaling is applied (TODO confirm).\n",
        "  resized_image, _unused_info = resize_and_crop_image(\n",
        "      image,\n",
        "      input_image_size,\n",
        "      padded_size=input_image_size,\n",
        "      aug_scale_min=1.0,\n",
        "      aug_scale_max=1.0)\n",
        "  return resized_image"
      ],
      "metadata": {
        "id": "Ql177vJ3hUXj"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Visualize original test data"
      ],
      "metadata": {
        "id": "pkyLaTNkhbz0"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "num_of_examples = 3\n",
        "\n",
        "# Read the test split from the path configured earlier rather than a Colab-only absolute path.\n",
        "test_ds = tf.data.TFRecordDataset(test_data_input_path).take(num_of_examples)\n",
        "show_batch(test_ds, num_of_examples)"
      ],
      "metadata": {
        "id": "nWqc5lKiheQa"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Inference on test data"
      ],
      "metadata": {
        "id": "iWynYjEOhhv_"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "input_image_size = (HEIGHT, WIDTH)\n",
        "plt.figure(figsize=(20, 20))\n",
        "min_score_thresh = 0.3 # Change minimum score for threshold to see all bounding boxes confidences.\n",
        "\n",
        "for i, serialized_example in enumerate(test_ds):\n",
        "  # One column per example in test_ds; tied to num_of_examples instead of a hard-coded 3.\n",
        "  plt.subplot(1, num_of_examples, i+1)\n",
        "  decoded_tensors = tf_ex_decoder.decode(serialized_example)\n",
        "  # Resize to the model input size, add the batch axis, and cast to uint8 before calling the model.\n",
        "  image = build_inputs_for_object_detection(decoded_tensors['image'], input_image_size)\n",
        "  image = tf.expand_dims(image, axis=0)\n",
        "  image = tf.cast(image, dtype = tf.uint8)\n",
        "  image_np = image[0].numpy()\n",
        "  result = model_fn(image)\n",
        "  # Index [0] selects the single image of the batch from each output.\n",
        "  visualization_utils.visualize_boxes_and_labels_on_image_array(\n",
        "      image_np,\n",
        "      result['detection_boxes'][0].numpy(),\n",
        "      result['detection_classes'][0].numpy().astype(int),\n",
        "      result['detection_scores'][0].numpy(),\n",
        "      category_index=category_index,\n",
        "      use_normalized_coordinates=True,\n",
        "      max_boxes_to_draw=200,\n",
        "      min_score_thresh=min_score_thresh,\n",
        "      agnostic_mode=False,\n",
        "      instance_masks=None,\n",
        "      line_thickness=4)\n",
        "  plt.imshow(image_np)\n",
        "  plt.axis('off')\n",
        "  # Title each panel the same way show_batch does, so predictions line up with the ground-truth figure.\n",
        "  plt.title(f'Image-{i+1}')\n",
        "\n",
        "plt.show()"
      ],
      "metadata": {
        "id": "9Gvvk1G1hgCD"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}